# ------------------------------------------------------------------------------
# Plots how physician testing errors (high-risk untested, low-risk tested) are
# distributed across quintiles of simplicity and representativeness bias
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 07_bias_yhats.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(xtable)
library(OneR)

# Project helper modules: `u` provides the join / clustered-SE helpers and `a`
# the plotting helpers used throughout this script.
u <- modules::use(here("lib", "util.R"))
a <- modules::use(here("lib", "aesthetics.R"))
a$get_font("Optima", here("lib", "optima.ttf"))

# Output directory for every figure and CSV written below.
temp <- here("code", "07_other_physician_errors", "temp")

# NOTE(review): not referenced directly in this script; presumably
# interpolated by glue() into paths$analysis$test_cohort — confirm.
overnight_lab <- ""

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- here("lib", "filepaths.yml") %>%
  read_yaml()
risk_cutoff <- paths$analysis$costeff_risk_cutoff %>%
  readRDS()

# Representativeness Errors
# Test-set scores from the "represent" model, reduced to the encounter id and
# the ensemble score (exposed as `yhat_represent`).
message("Loading representative symptom risk cohort...")
risk_scores_represent <- file.path(
  paths$modeling$dir, "prediction", "random", "represent",
  "scores_test_set.rds"
) %>%
  readRDS() %>%
  select(
    ed_enc_id,
    yhat_represent = p__ensemble__stent_or_cabg_010_day__tested__represent
  )

# Simple-model predictions from the behavioral lasso. The file is filtered on
# an `n_coef` column, and `score` is presumably a list-column holding one
# prediction vector per fit. The first fit with `simple_n_coef` coefficients
# is used.
simple_n_coef <- 49
lasso_predictions <- readRDS(
  here(
    "code", "06_physician_boundedness", "01_behavioral_lasso", "temp",
    "scores__lasso__stent_or_cabg_010_day.rds"
  )
)
lasso_simple <- filter(lasso_predictions, n_coef == simple_n_coef)
if (nrow(lasso_simple) == 0) {
  stop(
    "No lasso fit with n_coef == ", simple_n_coef, " in lasso predictions.",
    call. = FALSE
  )
}
yhat_simple <- lasso_simple$score[[1]]

# Analysis cohort: the test cohort with both biased predictions attached,
# restricted to non-excluded encounters with a recorded start time.
# NOTE(review): `yhat_simple` is attached positionally *before* filtering,
# so this assumes the lasso scores are row-aligned with the unfiltered test
# cohort — confirm.
cohort <- readRDS(glue(paths$analysis$test_cohort)) %>%
  # `.env$` makes sure the global vector is used even if a column shares
  # its name.
  mutate(yhat_simple = .env$yhat_simple) %>%
  filter(!exclude) %>%
  filter(!is.na(start_datetime)) %>%
  u$safe_left_join(risk_scores_represent) %>%
  mutate(
    # ~90% of yhat_represent values equal the median, so flagging values
    # strictly above it selects a little less than the top decile.
    high_yhat_represent = yhat_represent > median(yhat_represent),
    # Gap between each biased prediction and the full-model prediction,
    # plus content-based (roughly equal-count) 5-way bins labelled 1-5.
    yhat_diff_represent =
      yhat_represent - p__ensemble__stent_or_cabg_010_day__tested,
    tile_diff_represent = bin(
      yhat_diff_represent, nbins = 5, labels = 1:5, method = "content"
    ),
    yhat_diff_simple =
      yhat_simple - p__ensemble__stent_or_cabg_010_day__tested,
    tile_diff_simple = bin(
      yhat_diff_simple, nbins = 5, labels = 1:5, method = "content"
    ),
    tile_simple = bin(
      yhat_simple, nbins = 5, labels = 1:5, method = "content"
    ),
    # Testing errors: top tile of tile_stent_or_cabg_005_tested but
    # `test_010_day` FALSE, and full-model risk below the cost-effectiveness
    # cutoff but `test_010_day` TRUE.
    high_risk_untested = tile_stent_or_cabg_005_tested == 5 & !test_010_day,
    low_risk_tested =
      (p__ensemble__stent_or_cabg_010_day__tested < risk_cutoff) &
      test_010_day
  )
# Empirical CDFs of each biased prediction over the analysis cohort; calling
# one maps a prediction to its percentile in [0, 1].
pctile_represent <- ecdf(cohort$yhat_represent)
pctile_simple <- ecdf(cohort$yhat_simple)

cohort <- cohort %>%
  mutate(
    # pctile_yhat_represent = pctile_represent(yhat_represent),
    pctile_yhat_simple = pctile_simple(yhat_simple),
    # NOTE(review): pctile_yhat_simple lies in [0, 1]; confirm that
    # tile_stent_or_cabg_100_tested is on the same scale before differencing.
    pctile_diff_simple = pctile_yhat_simple - tile_stent_or_cabg_100_tested,
    tile_pctile_diff_simple = bin(
      pctile_diff_simple, nbins = 5, labels = 1:5, method = "content"
    )
  )

# Denominators for the per-tile shares of each error group computed below.
total_high_risk_untested <- sum(cohort$high_risk_untested)
total_low_risk_tested <- sum(cohort$low_risk_tested)

# For each tiling variable, print each tile's size, mean (simple - full)
# prediction gap, and share of all high-risk-untested / low-risk-tested
# patients.
tiles <- c(
  "tile_diff_represent", "tile_diff_simple",
  "tile_simple"
)
for (tile_var in tiles) {
  message("Tile Var: ", tile_var)
  # rename() returns a new table, so `cohort` itself is left untouched.
  tile_rates <- cohort %>%
    rename(tile = all_of(tile_var)) %>%
    group_by(tile) %>%
    summarize(
      n = n(),
      mean_diff_simple = mean(yhat_diff_simple),
      n_high_risk_untested = sum(high_risk_untested),
      n_low_risk_tested = sum(low_risk_tested),
      pct_of_high_risk_untested =
        n_high_risk_untested / total_high_risk_untested,
      pct_of_low_risk_tested = n_low_risk_tested / total_low_risk_tested
    ) %>%
    ungroup()

  print(select(
    tile_rates,
    tile, n, mean_diff_simple,
    pct_of_high_risk_untested, pct_of_low_risk_tested
  ))
}

# Working copy of the cohort used by the plotting sections below.
df <- copy(cohort)

# Simple-Full Plot -------------------------------------------------------------
message("Plotting high-risk untested/low-risk tested for simple diff yhat...")
# Share of all high-risk-untested and of all low-risk-tested patients that
# falls in each quintile of (simple - full) predicted risk.
simple_full_rates <- df %>%
  group_by(tile_diff_simple) %>%
  summarize(
    n = n(),
    n_high_risk_untested = sum(high_risk_untested),
    n_low_risk_tested = sum(low_risk_tested),
    pct_of_high_risk_untested =
      n_high_risk_untested / total_high_risk_untested,
    pct_of_low_risk_tested = n_low_risk_tested / total_low_risk_tested
  ) %>%
  ungroup()

# Patient-clustered SE of the in-tile share within one error subgroup;
# `group_col` names a logical column of `df` selecting that subgroup.
simple_tile_se <- function(tile_level, group_col) {
  tile_df <- df %>%
    mutate(in_tile = as.character(tile_diff_simple) == .env$tile_level) %>%
    filter(.data[[group_col]])
  u$clustered_se(
    data = tile_df,
    obs_col_name = "in_tile",
    cluster_by_col_name = "ptid"
  )
}

# Compute one SE per row of `simple_full_rates`, in its row order, so each SE
# lines up with the tile it describes.
simple_tile_levels <- as.character(simple_full_rates$tile_diff_simple)
simple_full_rates$high_risk_unt_se <- vapply(
  simple_tile_levels, simple_tile_se, numeric(1),
  group_col = "high_risk_untested", USE.NAMES = FALSE
)
simple_full_rates$low_risk_t_se <- vapply(
  simple_tile_levels, simple_tile_se, numeric(1),
  group_col = "low_risk_tested", USE.NAMES = FALSE
)

# reshape: one row per (tile, patient group) with the share (`beta`), its SE,
# and a +/- 1.96 SE interval; rows are ordered by group, then tile.
SEs <- simple_full_rates %>%
  select(
    tile_diff_simple,
    high_risk_unt = high_risk_unt_se,
    low_risk_t = low_risk_t_se
  ) %>%
  pivot_longer(
    c(high_risk_unt, low_risk_t),
    names_to = "grouping", values_to = "SE"
  )
rates <- simple_full_rates %>%
  select(
    tile_diff_simple,
    high_risk_unt = pct_of_high_risk_untested,
    low_risk_t = pct_of_low_risk_tested
  ) %>%
  pivot_longer(
    c(high_risk_unt, low_risk_t),
    names_to = "grouping", values_to = "beta"
  ) %>%
  u$safe_left_join(SEs) %>%
  mutate(
    beta_lo = beta - 1.96 * SE,
    beta_hi = beta + 1.96 * SE,
    x_var = as.numeric(tile_diff_simple),
    # Explicit levels keep each label tied to its key.
    grouping = factor(
      grouping,
      levels = c("high_risk_unt", "low_risk_t"),
      labels = c("High Risk Untested", "Low Risk Tested")
    )
  ) %>%
  arrange(grouping, tile_diff_simple)

# Both patient groups on one figure, plus the underlying long table as CSV.
gg <- a$tile_plot(rates) +
  labs(
    x = "Percentile of (Simple - Full) Predicted Risk",
    y = "Fraction",
    color = "Patient Group",
    fill = "Patient Group"
  )

message("Saving...")
ggsave(
  file.path(temp, "simplicity_bias_rates.png"),
  plot = gg, width = 10, height = 7, units = "in"
)
write_csv(rates, file.path(temp, "simplicity_bias_quintile_rates.csv"))

# Just HR-Untested
# Bar chart of high-risk-untested shares by quintile (x relabelled 20..100)
# with beta_lo/beta_hi error bars.
hr_df <- rates %>%
  filter(grouping == "High Risk Untested") %>%
  mutate(x_var = factor(x_var, labels = seq(20, 100, 20)))
gg_hr <- ggplot(hr_df, aes(x = x_var, y = beta)) +
  labs(
    x = "Percentile of (Simple - Full) Predicted Risk",
    y = "Fraction"
  ) +
  geom_col(color = "black", fill = a$disc_palette[2], width = 1) +
  geom_errorbar(aes(ymin = beta_lo, ymax = beta_hi), width = 0.2) +
  theme_bw() +
  theme(
    legend.position = "none",
    text = element_text(family = "Optima", size = 60)
  )

message("Saving high-risk untested simplicity bias rates...")
ggsave(
  file.path(temp, "simplicity_bias_HR_rates.png"),
  plot = gg_hr, width = 10, height = 7, units = "in"
)

# Just LR-Tested
# Bar chart of low-risk-tested shares by quintile (x relabelled 20..100)
# with beta_lo/beta_hi error bars.
lr_df <- rates %>%
  filter(grouping == "Low Risk Tested") %>%
  mutate(x_var = factor(x_var, labels = seq(20, 100, 20)))
gg_lr <- ggplot(lr_df, aes(x = x_var, y = beta)) +
  labs(
    x = "Percentile of (Simple - Full) Predicted Risk",
    y = "Fraction"
  ) +
  geom_col(color = "black", fill = a$disc_palette[1], width = 1) +
  geom_errorbar(aes(ymin = beta_lo, ymax = beta_hi), width = 0.2) +
  theme_bw() +
  theme(
    legend.position = "none",
    text = element_text(family = "Optima", size = 60)
  )

message("Saving low-risk tested simplicity bias rates...")
ggsave(
  file.path(temp, "simplicity_bias_LR_rates.png"),
  plot = gg_lr, width = 10, height = 7, units = "in"
)

# Representativeness Bias ------------------------------------------------------
message(
  "Plotting high-risk untested/low-risk tested for representative diff yhat..."
)
# Share of all high-risk-untested and of all low-risk-tested patients that
# falls in each quintile of (representative - full) predicted risk.
representative_bias_rates <- df %>%
  group_by(tile_diff_represent) %>%
  summarize(
    n = n(),
    n_high_risk_untested = sum(high_risk_untested),
    n_low_risk_tested = sum(low_risk_tested),
    pct_of_high_risk_untested =
      n_high_risk_untested / total_high_risk_untested,
    pct_of_low_risk_tested = n_low_risk_tested / total_low_risk_tested
  ) %>%
  ungroup()

# Patient-clustered SE of the in-tile share within one error subgroup;
# `group_col` names a logical column of `df` selecting that subgroup.
represent_tile_se <- function(tile_level, group_col) {
  tile_df <- df %>%
    mutate(
      in_tile = as.character(tile_diff_represent) == .env$tile_level
    ) %>%
    filter(.data[[group_col]])
  u$clustered_se(
    data = tile_df,
    obs_col_name = "in_tile",
    cluster_by_col_name = "ptid"
  )
}

# Compute one SE per row of `representative_bias_rates`, in its row order,
# so each SE lines up with the tile it describes.
represent_tile_levels <- as.character(
  representative_bias_rates$tile_diff_represent
)
representative_bias_rates$high_risk_unt_se <- vapply(
  represent_tile_levels, represent_tile_se, numeric(1),
  group_col = "high_risk_untested", USE.NAMES = FALSE
)
representative_bias_rates$low_risk_t_se <- vapply(
  represent_tile_levels, represent_tile_se, numeric(1),
  group_col = "low_risk_tested", USE.NAMES = FALSE
)

# reshape: one row per (tile, patient group) with the share (`beta`), its SE,
# and a +/- 1.96 SE interval; rows are ordered by group, then tile.
SEs <- representative_bias_rates %>%
  select(
    tile_diff_represent,
    high_risk_unt = high_risk_unt_se,
    low_risk_t = low_risk_t_se
  ) %>%
  pivot_longer(
    c(high_risk_unt, low_risk_t),
    names_to = "grouping", values_to = "SE"
  )
rates <- representative_bias_rates %>%
  select(
    tile_diff_represent,
    high_risk_unt = pct_of_high_risk_untested,
    low_risk_t = pct_of_low_risk_tested
  ) %>%
  pivot_longer(
    c(high_risk_unt, low_risk_t),
    names_to = "grouping", values_to = "beta"
  ) %>%
  u$safe_left_join(SEs) %>%
  mutate(
    beta_lo = beta - 1.96 * SE,
    beta_hi = beta + 1.96 * SE,
    x_var = as.numeric(tile_diff_represent),
    # Explicit levels keep each label tied to its key.
    grouping = factor(
      grouping,
      levels = c("high_risk_unt", "low_risk_t"),
      labels = c("High Risk Untested", "Low Risk Tested")
    )
  ) %>%
  arrange(grouping, tile_diff_represent)

# Both patient groups on one figure, plus the underlying long table as CSV.
# The figure gets its own file so the simplicity figure is not overwritten.
gg <- a$tile_plot(rates) +
  labs(
    x = "Percentile of (Representative - Full) Predicted Risk",
    y = "Fraction",
    color = "Patient Group",
    fill = "Patient Group"
  )

message("Saving...")
ggsave(
  file.path(temp, "representative_bias_rates.png"),
  plot = gg, width = 10, height = 7, units = "in"
)
write_csv(rates, file.path(temp, "representative_bias_quintile_rates.csv"))

# Just HR-Untested
# Bar chart of high-risk-untested shares by quintile (x relabelled 20..100)
# with beta_lo/beta_hi error bars.
hr_df <- rates %>%
  filter(grouping == "High Risk Untested") %>%
  mutate(x_var = factor(x_var, labels = seq(20, 100, 20)))
gg_hr <- ggplot(hr_df, aes(x = x_var, y = beta)) +
  labs(
    x = "Percentile of (Representative - Full) Predicted Risk",
    y = "Fraction"
  ) +
  geom_col(color = "black", fill = a$disc_palette[2], width = 1) +
  geom_errorbar(aes(ymin = beta_lo, ymax = beta_hi), width = 0.2) +
  theme_bw() +
  theme(
    legend.position = "none",
    text = element_text(family = "Optima", size = 60)
  )

message("Saving high-risk untested representative bias rates...")
ggsave(
  file.path(temp, "representative_bias_HR_rates.png"),
  plot = gg_hr, width = 10, height = 7, units = "in"
)

# Just LR-Tested
# Bar chart of low-risk-tested shares by quintile (x relabelled 20..100)
# with beta_lo/beta_hi error bars.
lr_df <- rates %>%
  filter(grouping == "Low Risk Tested") %>%
  mutate(x_var = factor(x_var, labels = seq(20, 100, 20)))
gg_lr <- ggplot(lr_df, aes(x = x_var, y = beta)) +
  labs(
    x = "Percentile of (Representative - Full) Predicted Risk",
    y = "Fraction"
  ) +
  geom_col(color = "black", fill = a$disc_palette[1], width = 1) +
  geom_errorbar(aes(ymin = beta_lo, ymax = beta_hi), width = 0.2) +
  theme_bw() +
  theme(
    legend.position = "none",
    text = element_text(family = "Optima", size = 60)
  )

message("Saving low-risk tested representative bias rates...")
ggsave(
  file.path(temp, "representative_bias_LR_rates.png"),
  plot = gg_lr, width = 10, height = 7, units = "in"
)

message("Done.")
